import os
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
import torch
import torch.nn as nn
import pytorch_lightning as pl
from dataset import EEGDataModuleBalanced
from argparse import ArgumentParser
import torch.nn.functional as F
from argparse import Namespace
import logging
import datetime
import os.path as op
from pytorch_lightning.callbacks import ModelCheckpoint
import json
from model import CNN_baseline

class EEG2EMB_CTS(pl.LightningModule):
    """Contrastive EEG-to-image-embedding model trained with an InfoNCE loss.

    Wraps ``CNN_baseline`` as the EEG encoder. Batches are unpacked as
    ``(eeg, label, image_features)``: ``image_features[i]`` is treated as the
    positive for ``eeg[i]`` during training, and ``label[i]`` is used as an
    index into the dataset's ``image_feat_cls`` anchors during validation and
    testing (presumably one anchor per class -- confirm against
    ``EEGDataModuleBalanced``).

    Args:
        args: Hyper-parameters; ``lr`` and ``temp`` are read here.
        logger: Logger receiving per-epoch metric summaries via ``.info``.
    """

    def __init__(self, args: Namespace, logger):
        super().__init__()
        self.args = args
        self.model = CNN_baseline()
        # Plain-text metric log; presumably named `txtlogger` so it does not
        # clash with Lightning's own `logger` attribute.
        self.txtlogger = logger

    def info_nce_loss(self, features, image_feature, labels):
        """InfoNCE loss with in-batch negatives.

        Row ``i`` of ``features`` is scored against every row of
        ``image_feature`` by cosine similarity divided by ``args.temp``.
        Column ``i`` is the positive and every other column is a negative.

        Args:
            features: EEG embeddings, shape ``(n, d)``.
            image_feature: Paired image embeddings, shape ``(n, d)``.
            labels: Class labels of the batch. Not used by the loss; kept for
                interface compatibility. Consequently two samples of the same
                class in one batch count as negatives of each other.

        Returns:
            Scalar cross-entropy loss.
        """
        features = F.normalize(features, dim=1)
        image_feature = F.normalize(image_feature, dim=1)
        logits = torch.matmul(features, image_feature.T) / self.args.temp
        # Target i on row i is equivalent to the "[positive, negatives...]
        # with target 0" formulation: a row's log-sum-exp does not depend on
        # column order.
        targets = torch.arange(logits.shape[0], device=logits.device)
        return F.cross_entropy(logits, targets)

    def rank_acc(self, x, y, label):
        """Retrieval metrics for queries ``x`` against candidates ``y``.

        Args:
            x: Query embeddings, shape ``(n, d)``.
            y: Candidate embeddings, shape ``(m, d)``.
            label: Index into ``y`` of the correct candidate for each query,
                shape ``(n,)``.

        Returns:
            ``(top1, top5, rank_acc)`` computed by ``rank_acc_logits`` on the
            cosine-similarity matrix of ``x`` and ``y``.
        """
        x = F.normalize(x, dim=1)
        y = F.normalize(y, dim=1)
        return self.rank_acc_logits(torch.matmul(x, y.T), label)

    def rank_acc_logits(self, logits, label):
        """Top-1, top-5 and normalized-rank accuracy from a score matrix.

        Args:
            logits: Scores of shape ``(n, m)``; higher means more similar.
            label: Column index of the correct candidate for each row,
                shape ``(n,)``. Out-of-range labels raise an indexing error
                instead of being silently dropped.

        Returns:
            Scalar tensors ``(top1, top5, rank_acc)``. ``rank_acc`` is
            ``1 - rank / (m - 1)`` averaged over rows: 1.0 means the target is
            always first, 0.0 means it is always last.
        """
        target_score = logits.gather(1, label.long().reshape(-1, 1))
        # 0-based rank = number of candidates scored strictly higher than the
        # target. Exact ties (e.g. duplicate candidates) resolve in the
        # target's favour rather than in arbitrary sort order.
        rank = (logits > target_score).sum(dim=1)
        top1 = (rank == 0).float().mean()
        top5 = (rank < 5).float().mean()
        # max(..., 1) guards m == 1, where (m - 1) would be a division by zero.
        normalized = 1.0 - rank.float() / max(logits.shape[1] - 1, 1)
        return top1, top5, normalized.mean()

    def forward(self, x):
        """Encode a batch of EEG inputs with the wrapped CNN."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        eeg, label, image_features = batch
        features = self(eeg)
        loss = self.info_nce_loss(features, image_features, label)
        # Rank against the in-batch image features using the same targets as
        # the loss (row i <-> column i). `label` indexes the class-level
        # anchors (see validation_step), not columns of this n x n matrix.
        targets = torch.arange(features.shape[0], device=features.device)
        top1, top5, rank_acc = self.rank_acc(features, image_features, targets)
        self.log('loss', loss, prog_bar=True)
        # Reduce over the epoch so on_train_epoch_end reports epoch means
        # rather than the value of the last training batch.
        self.log('top1', top1, on_step=False, on_epoch=True)
        self.log('top5', top5, on_step=False, on_epoch=True)
        self.log('rank_acc', rank_acc, on_step=False, on_epoch=True)
        return loss

    def on_train_epoch_end(self) -> None:
        """Write the epoch-level training retrieval metrics to the text log."""
        metrics = self.trainer.callback_metrics
        avg_rank_acc = metrics['rank_acc'].item()
        avg_top1 = metrics['top1'].item()
        avg_top5 = metrics['top5'].item()
        self.txtlogger.info(f"Epoch {self.current_epoch}, "
                            f"rank_acc: {avg_rank_acc:.4f}, "
                            f"top1: {avg_top1:.4f}, "
                            f"top5: {avg_top5:.4f}")

    def _retrieval_step(self, batch, anchors, prefix):
        """Log retrieval metrics of the batch's EEG embeddings against ``anchors``.

        Args:
            batch: ``(eeg, label, image_features)``; ``label`` indexes rows of
                ``anchors`` and ``image_features`` is ignored.
            anchors: Candidate embeddings, shape ``(num_anchors, d)``.
            prefix: Metric-name prefix, e.g. ``'val'`` or ``'test'``.
        """
        eeg, label, _ = batch
        output = self.model(eeg)
        # Anchors live on the dataset object; move them next to the model
        # output (a no-op when they are already on that device).
        anchors = anchors.to(output.device)
        top1, top5, rank_acc = self.rank_acc(output, anchors, label)
        self.log(f'{prefix}_top1', top1)
        self.log(f'{prefix}_top5', top5)
        self.log(f'{prefix}_rank_acc', rank_acc)

    def validation_step(self, batch, batch_idx):
        # NOTE(review): validation scores against the *training* set's anchors
        # while testing uses the test set's; presumably val shares train's
        # classes -- confirm.
        self._retrieval_step(
            batch, self.trainer.datamodule.train_dataset.image_feat_cls, 'val')

    def on_validation_epoch_end(self):
        """Write the epoch-level validation metrics to the text log."""
        # The pre-fit sanity check only runs a few batches; keep it out of the
        # accuracy log so it is not mistaken for a real epoch-0 result.
        if getattr(self.trainer, 'sanity_checking', False):
            return
        metrics = self.trainer.callback_metrics
        val_top1 = metrics['val_top1'].item()
        val_top5 = metrics['val_top5'].item()
        val_rank_acc = metrics['val_rank_acc'].item()
        self.txtlogger.info(f"Epoch {self.current_epoch}, "
                            f"val_rank_acc: {val_rank_acc:.4f}, "
                            f"val_top1: {val_top1:.4f}, "
                            f"val_top5: {val_top5:.4f}")

    def configure_optimizers(self):
        """AdamW at ``args.lr``, multiplied by 0.8 every 100 epochs."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.args.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1,
            }
        }

    def on_test_epoch_start(self) -> None:
        # Emit an empty line before testing output.
        print('')

    def test_step(self, batch, batch_idx):
        self._retrieval_step(
            batch, self.trainer.datamodule.test_dataset.image_feat_cls, 'test')

    def on_test_epoch_end(self) -> None:
        """Write the test metrics and the best checkpoint path to the text log."""
        metrics = self.trainer.callback_metrics
        test_top1 = metrics['test_top1'].item()
        test_top5 = metrics['test_top5'].item()
        test_rank_acc = metrics['test_rank_acc'].item()
        ckpt_callback = self.trainer.checkpoint_callback
        best_path = ckpt_callback.best_model_path if ckpt_callback is not None else None
        self.txtlogger.info('------'*10)
        self.txtlogger.info(f"Best model: {best_path}")
        self.txtlogger.info(f"Test on best model, "
                            f"test_top1: {test_top1}, "
                            f"test_top5: {test_top5}, "
                            f"test_rank_acc: {test_rank_acc}")

def setup_logging(log_dir, fold):
    """Return an INFO-level logger that writes to ``<log_dir>/acc_log_<fold>.txt``.

    Creates ``log_dir`` (and any missing parents) first. The logger is named
    ``acc_log_<fold>``. Because ``logging.getLogger`` returns the same object
    for the same name, calling this again for the same directory and fold
    reuses the existing file handler instead of attaching a second one, which
    would duplicate every line.

    Args:
        log_dir: Directory that will hold the log file.
        fold: Identifier (e.g. subject index) used in the logger and file names.

    Returns:
        The configured ``logging.Logger``.
    """
    # exist_ok avoids the check-then-create race of op.exists + os.makedirs.
    os.makedirs(log_dir, exist_ok=True)
    log_path = op.abspath(op.join(log_dir, f'acc_log_{fold}.txt'))
    logger = logging.getLogger(f'acc_log_{fold}')
    logger.setLevel(logging.INFO)
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
        for handler in logger.handlers
    )
    if not already_attached:
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
        logger.addHandler(file_handler)
    return logger
def main():
    """Train and test one ``EEG2EMB_CTS`` model per subject.

    For each subject index in ``range(args.num_subjects)``, a fresh model is
    trained with ``ModelCheckpoint`` tracking the best ``val_top1``. The best
    checkpoint is then reloaded and evaluated with ``trainer.test``. Logs and
    checkpoints go to ``<log_dir>/<MM-DD_HH-MM>/sub_<i>``.
    """
    parser = ArgumentParser()
    parser.add_argument("--max_epochs", type=int, default=1000)
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--log_dir", type=str, default='logs/main_reg2')
    parser.add_argument("--dataset", type=str, default='WM')
    parser.add_argument("--temp", type=float, default=0.07)
    parser.add_argument("--num_subjects", type=int, default=10,
                        help="run subjects 0 .. num_subjects-1")
    args = parser.parse_args()
    current_time = datetime.datetime.now().strftime("%m-%d_%H-%M")
    fold_dir = op.join(args.log_dir, current_time)
    for sub in range(args.num_subjects):
        fold_dir_sub = op.join(fold_dir, f'sub_{sub}')
        # `args` is dumped to the log and handed to the model, so point its
        # log_dir at this subject's directory.
        args.log_dir = fold_dir_sub
        checkpoint_callback = ModelCheckpoint(
            monitor='val_top1',
            mode='max',
            dirpath=fold_dir_sub,
            filename='model-{epoch:02d}-{val_top1:.4f}',
            save_top_k=1,
            save_last=True,
            verbose=True,
        )
        dataset = EEGDataModuleBalanced(dataset=args.dataset,
                                        subject=sub,
                                        batch_size=args.batch_size)
        logger = setup_logging(args.log_dir, sub)
        logger.info(json.dumps(vars(args), indent=4))
        logger.info(f"subject: {sub}")
        logger.info(f"dataset: {args.dataset}")
        model = EEG2EMB_CTS(args=args, logger=logger)
        logger.info(model)
        trainer = pl.Trainer(max_epochs=args.max_epochs, callbacks=[checkpoint_callback], gpus=1)
        trainer.fit(model, datamodule=dataset)
        # Reload the best weights before testing. best_model_path is empty
        # when no checkpoint was ever saved (e.g. val_top1 never logged);
        # torch.load('') would crash, so fall back to the final weights.
        best_model_path = checkpoint_callback.best_model_path
        if best_model_path:
            checkpoint = torch.load(best_model_path, map_location='cpu')
            model.load_state_dict(checkpoint['state_dict'])
        else:
            logger.warning("No best checkpoint was saved; testing the final-epoch weights.")
        trainer.test(model, datamodule=dataset)


if __name__ == '__main__':
    main()

